# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "astropy",
#     "disc-limo",
#     "discminer",
#     "matplotlib",
#     "numpy",
# ]
#
# [tool.uv.sources]
# disc-limo = { git = "https://github.com/TomHilder/disc_limo.git" }
# discminer = { git = "https://github.com/andizq/discminer.git" }
# ///

from copy import deepcopy
from pathlib import Path

import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
from astropy.io import fits
from disc_limo import (
    convert_weights_to_channels,
    get_design_matrices,
    get_posterior_samples,
)
from discminer.core import Data

# === Constants
# To process a different source, change BOTH file names below: the data cube
# and the posterior samples must come from the same fit.

# Data cube on which the disc_limo fit was performed.
CUBE_FILE_THAT_WAS_FIT = Path(
    "J1842_12CO.contsub_beam0.15_100ms_3sigma.clean.image_clipped_downsamp_2pix.fits"
)
# Posterior samples of the Fourier weights. main() parses the number of fitted
# pixels and Fourier modes out of this name ("<n>pix", "<n>modes"), so keep
# that naming scheme when swapping files.
SAMPLES_FILE = Path("J1842_100pix_125modes_100samples.npz")
# Distance to the source; handed to discminer's Data reader.
DISTANCE = 151 * u.pc
# How many posterior samples to turn into cubes (this samples file holds 100).
N_SAMPLES = 2
# Toggle the diagnostic comparison plots.
PLOTS = True


# === Sample cube generation
def _int_before_tag(name, tag):
    """Return the integer that directly precedes *tag* in an ``_``-separated name.

    Example: ``_int_before_tag("J1842_100pix_125modes_100samples.npz", "pix")``
    returns ``100``. Raises ``ValueError`` if that token is not an integer.
    """
    return int(name.split(tag)[0].split("_")[-1])


def _sample_cubes(samples, fourier_design, true_design, n_samples):
    """Evaluate the first *n_samples* posterior weight samples as image cubes.

    Args:
        samples: posterior Fourier weights indexed as ``samples[i, channel, mode]``.
        fourier_design, true_design: design matrices from ``get_design_matrices``.
        n_samples: number of leading samples to convert.

    Returns:
        ``(true_cubes, conv_cubes)``, float32 arrays stacked along a new leading
        sample axis (expected ``(n_samples, n_channels, n_pix_fit, n_pix_fit)``
        when ``n_eval == n_pix_fit``).
    """
    true_cubes = []
    conv_cubes = []
    for i in range(n_samples):
        print(f"Converting sample {i} to images")
        sample_weights = samples[i, :, :]
        # NOTE(review): the "true" cube is built from `fourier_design` and the
        # "convolved" cube from `true_design`. This pairing is kept from the
        # original script; confirm it against the order in which
        # disc_limo.get_design_matrices returns its matrices.
        true_cubes.append(
            convert_weights_to_channels(
                weights=sample_weights,
                design_matrix=fourier_design,
            ).astype(np.float32)
        )
        conv_cubes.append(
            convert_weights_to_channels(
                weights=sample_weights,
                design_matrix=true_design,
            ).astype(np.float32)
        )
    return np.array(true_cubes), np.array(conv_cubes)


def _plot_samples(data, sample_true_cubes, sample_conv_cubes, n_pix_fit):
    """Save and show a data-vs-sample comparison and a grid of up to 10 samples.

    Writes ``data_compared_samples.png`` and ``samples.png`` to the working
    directory.
    """
    n_samples = sample_true_cubes.shape[0]
    vmax = np.nanmax(data)
    imshow_kwargs = dict(vmin=-vmax, vmax=vmax, cmap="RdBu")
    # Channel to display; clamped so cubes with fewer channels don't IndexError.
    ch_i = min(45, data.shape[0] - 1)
    sample_i = 0

    # Crop the data to the central region so it lines up with the samples.
    # NOTE: only valid when n_eval == n_pix_fit; assumes a square image (the
    # last axis is used for both spatial axes); for odd n_pix_fit the crop is
    # one pixel smaller than the sample grid.
    centre = data.shape[-1] // 2
    half = n_pix_fit // 2
    lo, hi = centre - half, centre + half
    data_clipped = data[:, lo:hi, lo:hi]

    # Figure 1: data next to one sample (convolved and not convolved).
    _, axes = plt.subplots(
        1,
        3,
        layout="compressed",
        figsize=[12, 4],
        dpi=100,
        sharex=True,
        sharey=True,
    )
    panels = [
        (data_clipped[ch_i, :, :], "Data"),
        (sample_conv_cubes[sample_i, ch_i, :, :], "Sample (convolved w/ beam)"),
        (
            sample_true_cubes[sample_i, ch_i, :, :],
            "Sample (true intensity/not convolved)",
        ),
    ]
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image, **imshow_kwargs)
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.savefig("data_compared_samples.png", bbox_inches="tight")
    plt.show()

    # Figure 2: up to 10 samples, convolved (top row) and not convolved (bottom).
    n_samples_plot = min(n_samples, 10)
    # squeeze=False keeps `axes` 2-D even when only one sample is plotted.
    _, axes = plt.subplots(
        2,
        n_samples_plot,
        layout="compressed",
        figsize=[4 * n_samples_plot, 8],
        dpi=100,
        sharex=True,
        sharey=True,
        squeeze=False,
    )
    axes[0, 0].set_ylabel("Samples (convolved w/ beam)")
    axes[1, 0].set_ylabel("Samples (true intensity/not convolved)")
    for i in range(n_samples_plot):
        axes[0, i].imshow(sample_conv_cubes[i, ch_i, :, :], **imshow_kwargs)
        axes[1, i].imshow(sample_true_cubes[i, ch_i, :, :], **imshow_kwargs)
    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])
    plt.savefig("samples.png", bbox_inches="tight")
    plt.show()


def _replace_data_hdu(hdulist, data):
    """Return a deep copy of *hdulist* whose primary HDU holds *data*.

    Raises:
        ValueError: if *data* does not have the primary HDU's data shape.
    """
    if hdulist[0].data.shape != data.shape:
        raise ValueError(
            f"Sample shape {data.shape} does not match the clipped cube shape "
            f"{hdulist[0].data.shape}."
        )
    hdulist_temp = deepcopy(hdulist)
    hdulist_temp[0].data = data
    return hdulist_temp


def _write_sample_fits(sample_true_cubes, sample_conv_cubes, n_pix_fit):
    """Write every sample cube to FITS, reusing a clipped copy of the data's HDUs.

    discminer writes a clipped copy of CUBE_FILE_THAT_WAS_FIT (clipped to the fit
    size); its HDUList is used as a template and the temporary file is deleted
    afterwards, even if writing a sample fails. Outputs are
    ``sample_{i}_true.fits`` and ``sample_{i}_conv.fits`` next to the data cube.
    """
    datacube = Data(CUBE_FILE_THAT_WAS_FIT, DISTANCE)
    datacube.clip(npix=n_pix_fit // 2, overwrite=True)
    clipped_file = CUBE_FILE_THAT_WAS_FIT.with_name(
        f"{CUBE_FILE_THAT_WAS_FIT.name.split('.fits')[0]}_clipped.fits"
    )
    try:
        # Copy while the file is open so the template outlives the handle.
        with fits.open(clipped_file) as hdul:
            template = deepcopy(hdul)
        for i, (cube_true, cube_conv) in enumerate(
            zip(sample_true_cubes, sample_conv_cubes)
        ):
            print(f"Saving sample {i} to fits")
            _replace_data_hdu(template, cube_true).writeto(
                CUBE_FILE_THAT_WAS_FIT.with_name(f"sample_{i}_true.fits"),
                overwrite=True,
            )
            _replace_data_hdu(template, cube_conv).writeto(
                CUBE_FILE_THAT_WAS_FIT.with_name(f"sample_{i}_conv.fits"),
                overwrite=True,
            )
    finally:
        # Always remove discminer's temporary clipped cube.
        clipped_file.unlink(missing_ok=True)


def main():
    """Turn posterior Fourier-weight samples into sample cubes and FITS files.

    1. Load the posterior weights from SAMPLES_FILE (shape
       ``(n_stored_samples, n_channels, n_fourier_modes)``).
    2. Evaluate the first N_SAMPLES of them on the pixel grid of the fitted data
       (CUBE_FILE_THAT_WAS_FIT), giving a "true" and a "convolved" cube each.
       Evaluating on a denser grid is possible via ``n_eval`` but costs time,
       memory and disk space, and the plotting crop below assumes
       ``n_eval == n_pix_fit``.
    3. If PLOTS, plot a sample against the data and a grid of samples.
    4. Write ``sample_{i}_true.fits`` / ``sample_{i}_conv.fits``.

    Raises:
        ValueError: if the data cube and the samples disagree on the number of
            channels, or N_SAMPLES is outside ``1..n_stored_samples``.
    """
    with np.load(SAMPLES_FILE) as npz:
        samples = npz["samples"]

    # Unpacking three axes also checks that the data is a 3-D cube.
    with fits.open(CUBE_FILE_THAT_WAS_FIT) as hdul:
        n_ch_data, _, _ = hdul[0].shape
        data = hdul[0].data

    # The fit's pixel count and number of Fourier modes are encoded in the name.
    n_pix_fit = _int_before_tag(SAMPLES_FILE.name, "pix")
    n_fourier_fit = _int_before_tag(SAMPLES_FILE.name, "modes")

    if n_ch_data != samples.shape[1]:
        raise ValueError(
            "The fit and data don't have the same number of channels. "
            "Are you sure your fit was performed on the cube provided?"
        )
    if not 1 <= N_SAMPLES <= samples.shape[0]:
        raise ValueError(
            f"N_SAMPLES must be between 1 and {samples.shape[0]}, got {N_SAMPLES}."
        )

    # Design matrices (sines and cosines evaluated on the pixel grid) needed to
    # turn Fourier weights back into images.
    fourier_design, true_design = get_design_matrices(
        filename=CUBE_FILE_THAT_WAS_FIT,
        n_pix=n_pix_fit,
        n_fourier=n_fourier_fit,
        n_eval=n_pix_fit,
    )
    sample_true_cubes, sample_conv_cubes = _sample_cubes(
        samples, fourier_design, true_design, N_SAMPLES
    )

    if PLOTS:
        _plot_samples(data, sample_true_cubes, sample_conv_cubes, n_pix_fit)

    _write_sample_fits(sample_true_cubes, sample_conv_cubes, n_pix_fit)


# === Entry point: run the pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
